import warnings
import math
import copy
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from tools.file_utilities import make_folder_if_not_exists
# scipy libraries
import scipy.stats as ss
from scipy.stats import mannwhitneyu, ttest_ind
from scipy.stats.mstats import kruskalwallis
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Module-wide plotting defaults applied on import.
sns.set(style="white")
# Second call wins: all figures use the "whitegrid" style with color codes enabled.
sns.set(style="whitegrid", color_codes=True)
# Use matplotlib's built-in mathtext rather than requiring a LaTeX installation.
plt.rcParams['text.usetex'] = False

def rank_INT(series, c=3.0/8, stochastic=True):
    """ Perform rank-based inverse normal transformation on pandas series.
        If stochastic is True ties are given rank randomly, otherwise ties will
        share the same value. NaN values are ignored (kept as NaN in the output).

        Args:
            series (pandas.Series):   Series of values to transform
            c (Optional[float]):      Constant parameter (Blom's constant)
            stochastic (Optional[bool]): Whether to randomise rank of ties

        Returns:
            pandas.Series: transformed values, aligned to the original index;
            positions that were NaN in the input are NaN in the output.
    """

    # Check input
    assert(isinstance(series, pd.Series))
    assert(isinstance(c, float))
    assert(isinstance(stochastic, bool))

    # Fixed seed so stochastic tie-breaking is reproducible across calls
    np.random.seed(123)

    # Remember the original index so NaN rows can be restored at the end
    orig_idx = series.index

    # Drop NaNs before ranking
    series = series.loc[~pd.isnull(series)]

    # Get ranks
    if stochastic:
        # Shuffle by index so ties are broken randomly, then rank by position
        series = series.loc[np.random.permutation(series.index)]
        rank = ss.rankdata(series, method="ordinal")
    else:
        # Ties share the averaged rank
        rank = ss.rankdata(series, method="average")

    # Blom-type quantiles in (0, 1), then invert the standard normal CDF.
    # Vectorized equivalent of applying rank_to_normal element-wise.
    quantiles = (rank - c) / (len(rank) - 2*c + 1)
    transformed = pd.Series(ss.norm.ppf(quantiles), index=series.index)

    # BUG FIX: use reindex instead of transformed[orig_idx] — indexing with
    # labels that were dropped (NaN rows) raises KeyError on pandas >= 1.0;
    # reindex restores those positions as NaN.
    return transformed.reindex(orig_idx)

def rank_to_normal(rank, c, n):
    """Map a rank to a standard-normal quantile via the Blom-type formula.

    Args:
        rank: rank of the observation (1-based).
        c: offset constant (e.g. Blom's 3/8).
        n: total number of ranked observations.

    Returns:
        The standard normal quantile for the rank's probability.
    """
    # Convert the rank to a probability strictly inside (0, 1),
    # then invert the standard normal CDF.
    probability = (rank - c) / (n - 2*c + 1)
    return ss.norm.ppf(probability)


def subset_by_iqr(df, column, whisker_width=1.5):
    """Remove outliers from a dataframe by column, including optional
       whiskers, removing rows for which the column value are
       less than Q1-1.5IQR or greater than Q3+1.5IQR.
    Args:
        df (`:obj:pd.DataFrame`): A pandas dataframe to subset
        column (str): Name of the column to calculate the subset from.
        whisker_width (float): Optional, loosen the IQR filter by a
                               factor of `whisker_width` * IQR.
    Returns:
        (`:obj:pd.DataFrame`): Filtered dataframe
    """
    # Calculate Q1, Q3 and IQR
    q1 = df[column].quantile(0.25)
    q3 = df[column].quantile(0.75)
    iqr = q3 - q1
    # BUG FIX (idiom): the original named this mask `filter`, shadowing the
    # builtin. Series.between is inclusive on both ends, matching >= / <=.
    mask = df[column].between(q1 - whisker_width*iqr, q3 + whisker_width*iqr)
    return df.loc[mask]


def latexify_p_value(p):
    """Format a p-value as a LaTeX math string in scientific notation.

    Args:
        p (float): p-value to format.

    Returns:
        str: e.g. 1.23e-5 -> '$p=1.23\\times 10^{-5}$'.
    """
    # BUG FIX: the original string-replace chain only stripped the padding
    # zero for negative single-digit exponents ('E-05'), leaving positive
    # exponents padded ('10^{04}', '10^{00}'). Parsing the exponent as an
    # int normalizes every case.
    mantissa, exponent = ('%.2E' % p).split('E')
    return '$p=%s\\times 10^{%d}$' % (mantissa, int(exponent))

def make_lm_plot(input_table, signature, parameter, title, ylabel, col = None, savepath=None):
    """Draw a seaborn regression (lm) plot of `signature` against `parameter`.

    Args:
        input_table (pd.DataFrame): data to plot; non-numeric `parameter`
            columns are factorized to ordered integer codes first.
        signature (str): column plotted on the y axis.
        parameter (str): column plotted on the x axis.
        title (str): used only in the empty-table warning message.
        ylabel (str): currently unused by the plot body (kept for API compat).
        col (str, optional): column to facet on (one panel per value).
        savepath (str, optional): file to save to; if None the plot is shown.
    """
    if input_table.empty:
        warnings.warn("Attempting to plot an empty table: %s" % (title))
        return
    # BUG FIX: only derive and create the output folder when a savepath is
    # given — the original crashed with AttributeError on savepath=None
    # before ever reaching the plt.show() branch.
    if savepath is not None:
        output_folder = savepath.rsplit('/', 1)[0]
        make_folder_if_not_exists(output_folder)

    table = copy.deepcopy(input_table)
    if not is_numeric_dtype(table[parameter]):
        # Factorize categorical parameters so lmplot can fit a regression
        table[parameter], _ = pd.factorize(table[parameter], sort=True)
    plt.figure(figsize=(10, 7))
    sns.set(font_scale=2)
    if col:
        sns.lmplot(x=parameter, y=signature, col=col, data=table, x_estimator=np.mean,
               col_wrap=5, aspect=1)
    else:
        sns.lmplot(x=parameter, y=signature, data=table)
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath, transparent=True)
    else:
        plt.show()
    plt.close()
    # Restore the module-wide font scale for subsequent plots
    sns.set(font_scale=1)

def make_OR_plot(input_table, title, log_scale=True, savepath=None):
    """Plot odds ratios with confidence-interval error bars.

    Args:
        input_table (pd.DataFrame): must contain columns 'OR', '2.5%' and
            '97.5%' (CI bounds); the index labels the y axis.
        title (str): figure title.
        log_scale (bool): plot the OR axis on a log scale.
        savepath (str, optional): file to save to; if None the plot is shown.
    """
    if input_table.empty:
        warnings.warn("Attempting to plot an empty table: %s" % (title))
        return
    # BUG FIX: only derive and create the output folder when a savepath is
    # given — the original crashed with AttributeError on savepath=None.
    if savepath is not None:
        output_folder = savepath.rsplit('/', 1)[0]
        make_folder_if_not_exists(output_folder)

    table = copy.deepcopy(input_table)
    # errorbar() wants distances from the point, not absolute CI bounds
    table['2.5%'] = table['OR'] - table['2.5%']
    table['97.5%'] = table['97.5%'] - table['OR']
    plt.errorbar(table['OR'], table.index.to_list(), xerr=table[['2.5%','97.5%']].T.values, fmt = 'o', color = 'k')
    if log_scale:
        plt.xscale('log')
    # Reference line at OR = 1 (no effect)
    plt.axvline(1, ls='--')
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.title(title, fontsize=16)
    plt.grid()
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath, transparent=True)
    else:
        plt.show()
    plt.close()

def make_boxplot(input_table, signature, parameter, title, ylabel, show_mean=False, relative=True, verbose=True, add_jitter = False, savepath=None):
    """Draw side-by-side boxplots of `signature`, one panel per `parameter` value.

    Args:
        input_table (pd.DataFrame): data to plot.
        signature (str): column plotted on the y axis.
        parameter (str): column whose unique values define the panels.
        title (str): figure title.
        ylabel (str): y-axis label of the leftmost panel.
        show_mean (bool): currently unused (kept for API compatibility).
        relative (bool): currently unused (kept for API compatibility).
        verbose (bool): print mean/median per parameter value.
        add_jitter (bool): overlay a swarmplot and hide boxplot fliers.
        savepath (str, optional): file to save to; if None the plot is shown.
    """
    if input_table.empty:
        warnings.warn("Attempting to plot an empty table: %s/%s/%s" % (parameter, signature, title))
        return
    # BUG FIX: only derive and create the output folder when a savepath is
    # given — the original crashed with AttributeError on savepath=None.
    if savepath is not None:
        output_folder = savepath.rsplit('/', 1)[0]
        make_folder_if_not_exists(output_folder)

    table = copy.deepcopy(input_table)
    parameters = sorted(input_table[parameter].unique())

    f, axes = plt.subplots(1, len(parameters), sharey=True, figsize=(6, 4))
    f.suptitle(title, fontsize=12, y=0.999)
    # plt.subplots returns a bare Axes (not an array) for a single panel
    if len(parameters)==1:
        axes = [axes]

    for parameter_value, axis in zip(parameters, axes):
        sub_dataset = copy.deepcopy(table.loc[table[parameter] == parameter_value])
        sub_dataset[parameter] = parameter_value
        # fliers are hidden when jitter points would duplicate them
        sns.boxplot(x=parameter, y=signature, data=sub_dataset, ax=axis, showfliers = True if not add_jitter else False)

        if verbose:
            median = np.median(sub_dataset[signature].values)
            mean = np.mean(sub_dataset[signature].values)
            print('Sig %s attribution for %s : %f (mean), %f (median)' % (signature, parameter_value, mean, median))

        axis.yaxis.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5)

        if add_jitter:
            axis = sns.swarmplot(x=parameter, y=signature, data=sub_dataset, ax=axis, color=".25")
        axis.set(xlabel='N = %i' % len(sub_dataset.index), ylabel='')

    axes[0].set_ylabel(ylabel, size = 12)
    # Rotate tick labels if any panel's label is too long to fit horizontally
    rotate_labels = False
    for axis in axes:
        for label in axis.get_xticklabels():
            if len(label.get_text())>15:
                rotate_labels = True
                break
    if rotate_labels:
        for axis in axes:
            axis.set_xticklabels(axis.get_xticklabels(),rotation=90)

    f.set_tight_layout({'rect':[0, 0, 1, 0.95]})
    if savepath is not None:
        plt.savefig(savepath, transparent=True)
    else:
        plt.show()
    plt.close()

def calculate_p_value(input_table, parameter, signature, mann_whithey_for_two_arrays = True, verbose = False):
    """
    Calculate p-value for a specified non-parametric test using the input dataframe and parameter name
    """
    # Group the signature values by each distinct value of the parameter column
    groups = {}
    for value in input_table[parameter].unique():
        subset = input_table.loc[input_table[parameter] == value]
        groups[value] = subset[signature].values
        if verbose:
            print('Sample size for %s: %i' % (value, len(groups[value])))
            print(groups[value])

    samples = list(groups.values())

    # Fewer than two groups: the question is not testable, return -1 sentinel
    if len(samples) < 2:
        if verbose:
            warnings.warn("Not enough variation for parameter %s" % parameter)
        return -1

    # Exactly two groups: Mann-Whitney by default, Welch t-test otherwise
    if len(samples) == 2:
        if verbose:
            print("*"*25)
            print("%s: %s mutated gene vs no mutation in this gene" % (signature, parameter))
        if mann_whithey_for_two_arrays:
            return calculate_mann_whitney_significance(samples[0], samples[1], verbose = verbose)
        return calculate_welch_t_test_significance(samples[0], samples[1], verbose = verbose)

    # Three or more groups: Kruskal-Wallis H-test
    return calculate_kruskal_wallis_significance(samples, verbose = verbose)

def calculate_mann_whitney_significance(first_array, second_array, alternative = "two-sided", verbose = False):
    """
    Compute the Mann-Whitney-Wilcoxon rank test on input arrays.
    This is a non-parametric test of the null hypothesis that it is equally
    likely that a randomly selected value from one sample will be less than
    or greater than a randomly selected value from a second sample.
    Does not require the assumption of normal distributions.

    Parameters:
    ----------
    first_array, second_array: array_like
        Arrays of samples, should be one-dimensional.

    alternative: 'less', 'two-sided', or 'greater'
        Whether to get the p-value for the one-sided hypothesis ('less' or
        'greater') or for the two-sided hypothesis ('two-sided'). Defaults
        to 'two-sided'.

    Returns:
    -------
    The p-value of the hypothesis test, or math.inf if the test could not
    be computed (e.g. empty or otherwise invalid input arrays).
    -------
    """
    try:
        stat, p = mannwhitneyu(first_array, second_array, alternative = alternative)
        if verbose:
            print('Mann–Whitney U test Statistics=%f, p=%e' % (stat, p))
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; catch only genuine errors and keep the inf sentinel.
        p = math.inf
        if verbose:
            print('Mann–Whitney U test failed (check input arrays)')

    return p

def calculate_welch_t_test_significance(first_array, second_array, verbose = False):
    """
    Calculate the Welch T-test for the means of two independent samples of scores.
    This is a two-sided test for the null hypothesis that 2 independent samples
    have identical average (expected) values. This test assumes that the populations
    have non-identical variances by default (as opposed to scipy.stats default settings).
    Requires the assumption of normal distributions.

    Parameters:
    ----------
    first_array, second_array: array_like
        Arrays of samples, should be one-dimensional.

    Returns:
    -------
    The two-tailed p-value of the hypothesis test, or math.inf if the test
    could not be computed (check input arrays).
    -------
    """
    try:
        # equal_var=False selects Welch's test (unequal variances)
        stat, p = ttest_ind(first_array, second_array, equal_var = False)
        if verbose:
            print('Welch t-test Statistics=%f, p=%e' % (stat, p))
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; catch only genuine errors and keep the inf sentinel.
        p = math.inf
        if verbose:
            print('Welch t-test failed (check input arrays)')
    return p

def calculate_kruskal_wallis_significance(list_of_arrays, verbose = False):
    """
    Calculate the Kruskal-Wallis H-test for two or more independent samples.
    This is an extension of the Mann-Whithey U test. The null hypothesis is
    that the medians of all underlying distributions are equal.
    Does not require the assumption of normal distributions.

    Parameters:
    ----------
    list_of_arrays: list of array_like
        List of arrays of one-dimensional samples.

    Returns:
    -------
    The p-value of the hypothesis test, or math.inf if the test could not
    be computed (check input arrays).
    -------
    """
    try:
        # unpack so each array becomes one sample group for the H-test
        stat, p = kruskalwallis(*list_of_arrays)
        if verbose:
            print('Kruskal-Wallis H-test Statistics=%f, p=%e' % (stat, p))
    except Exception:
        # BUG FIX: a bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit; catch only genuine errors and keep the inf sentinel.
        p = math.inf
        if verbose:
            print('Kruskal-Wallis H-test failed (check input arrays)')

    return p
